// Copyright Epic Games, Inc. All Rights Reserved.

#include "zenhttp/auth/authmgr.h"

#include <zencore/compactbinary.h>
#include <zencore/compactbinarybuilder.h>
#include <zencore/compactbinaryvalidation.h>
#include <zencore/crypto.h>
#include <zencore/filesystem.h>
#include <zencore/logging.h>
#include <zenhttp/auth/oidc.h>
#include <zenutil/basicfile.h>

#include <condition_variable>
#include <memory>
#include <shared_mutex>
#include <thread>
#include <unordered_map>

#include <fmt/format.h>

namespace zen {

using namespace std::literals;

namespace details {
	IoBuffer ReadEncryptedFile(std::filesystem::path	   Path,
							   const AesKey256Bit&		   Key,
							   const AesIV128Bit&		   IV,
							   std::optional<std::string>& Reason)
	{
		FileContents Result = ReadFile(Path);

		if (Result.ErrorCode)
		{
			return IoBuffer();
		}

		IoBuffer EncryptedBuffer = Result.Flatten();

		if (EncryptedBuffer.GetSize() == 0)
		{
			return IoBuffer();
		}

		std::vector<uint8_t> DecryptionBuffer;
		DecryptionBuffer.resize(EncryptedBuffer.GetSize() + Aes::BlockSize);

		MemoryView DecryptedView = Aes::Decrypt(Key, IV, EncryptedBuffer, MakeMutableMemoryView(DecryptionBuffer), Reason);

		if (DecryptedView.IsEmpty())
		{
			return IoBuffer();
		}

		return IoBufferBuilder::MakeCloneFromMemory(DecryptedView);
	}

	void WriteEncryptedFile(std::filesystem::path		Path,
							IoBuffer					FileData,
							const AesKey256Bit&			Key,
							const AesIV128Bit&			IV,
							std::optional<std::string>& Reason)
	{
		if (FileData.GetSize() == 0)
		{
			return;
		}

		std::vector<uint8_t> EncryptionBuffer;
		EncryptionBuffer.resize(FileData.GetSize() + Aes::BlockSize);

		MemoryView EncryptedView = Aes::Encrypt(Key, IV, FileData, MakeMutableMemoryView(EncryptionBuffer), Reason);

		if (EncryptedView.IsEmpty())
		{
			return;
		}

		TemporaryFile::SafeWriteFile(Path, EncryptedView);
	}
}  // namespace details

class AuthMgrImpl final : public AuthMgr
{
	using Clock		= std::chrono::system_clock;
	using TimePoint = Clock::time_point;
	using Seconds	= std::chrono::seconds;

public:
	AuthMgrImpl(const AuthConfig& Config) : m_Config(Config), m_Log(logging::Get("auth"))
	{
		LoadState();

		m_BackgroundThread.Interval = Config.UpdateInterval;
		m_BackgroundThread.Thread	= std::thread(&AuthMgrImpl::BackgroundThreadEntry, this);
	}

	virtual ~AuthMgrImpl() { Shutdown(); }

	virtual void AddOpenIdProvider(const AddOpenIdProviderParams& Params) final
	{
		if (OpenIdProviderExist(Params.Name))
		{
			ZEN_DEBUG("OpenID provider '{}' already exist", Params.Name);
			return;
		}

		if (Params.Name.empty())
		{
			ZEN_WARN("add OpenID provider FAILED, reason 'invalid name'");
			return;
		}

		std::unique_ptr<OidcClient> Client =
			std::make_unique<OidcClient>(OidcClient::Options{.BaseUrl = Params.Url, .ClientId = Params.ClientId});

		if (const auto InitResult = Client->Initialize(); InitResult.Ok == false)
		{
			ZEN_WARN("query OpenID provider FAILED, reason '{}'", InitResult.Reason);
			return;
		}

		std::string NewProviderName = std::string(Params.Name);

		OpenIdProvider* NewProvider = nullptr;

		{
			std::unique_lock _(m_ProviderMutex);

			if (m_OpenIdProviders.contains(NewProviderName))
			{
				return;
			}

			auto InsertResult = m_OpenIdProviders.emplace(NewProviderName, std::make_unique<OpenIdProvider>());
			NewProvider		  = InsertResult.first->second.get();
		}

		NewProvider->Name		= std::string(Params.Name);
		NewProvider->Url		= std::string(Params.Url);
		NewProvider->ClientId	= std::string(Params.ClientId);
		NewProvider->HttpClient = std::move(Client);

		ZEN_INFO("added OpenID provider '{} - {}'", Params.Name, Params.Url);
	}

	virtual bool AddOpenIdToken(const AddOpenIdTokenParams& Params) final
	{
		if (Params.ProviderName.empty())
		{
			ZEN_WARN("trying add OpenID token with invalid provider name");
			return false;
		}

		if (Params.RefreshToken.empty())
		{
			ZEN_WARN("add OpenID token FAILED, reason 'Token invalid'");
			return false;
		}

		auto RefreshResult = RefreshOpenIdToken(Params.ProviderName, Params.RefreshToken);

		if (RefreshResult.Ok == false)
		{
			ZEN_WARN("refresh OpenId token FAILED, reason '{}'", RefreshResult.Reason);
			return false;
		}

		bool IsNew = false;

		{
			auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken,
									 .RefreshToken	= RefreshResult.RefreshToken,
									 .AccessToken	= fmt::format("Bearer {}"sv, RefreshResult.AccessToken),
									 .ExpireTime	= Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)};

			std::unique_lock _(m_TokenMutex);

			const auto InsertResult = m_OpenIdTokens.insert_or_assign(std::string(Params.ProviderName), std::move(Token));

			IsNew = InsertResult.second;
		}

		if (IsNew)
		{
			ZEN_INFO("added new OpenID token for provider '{}'", Params.ProviderName);
		}
		else
		{
			ZEN_INFO("updating OpenID token for provider '{}'", Params.ProviderName);
		}

		return true;
	}

	virtual OpenIdAccessToken GetOpenIdAccessToken(std::string_view ProviderName) final
	{
		std::unique_lock _(m_TokenMutex);

		if (auto It = m_OpenIdTokens.find(std::string(ProviderName)); It != m_OpenIdTokens.end())
		{
			const OpenIdToken& Token = It->second;

			return {.AccessToken = Token.AccessToken, .ExpireTime = Token.ExpireTime};
		}

		return {};
	}

private:
	bool OpenIdProviderExist(std::string_view ProviderName)
	{
		std::unique_lock _(m_ProviderMutex);

		return m_OpenIdProviders.contains(std::string(ProviderName));
	}

	OidcClient& GetOpenIdClient(std::string_view ProviderName)
	{
		std::unique_lock _(m_ProviderMutex);
		return *m_OpenIdProviders[std::string(ProviderName)]->HttpClient.get();
	}

	OidcClient::RefreshTokenResult RefreshOpenIdToken(std::string_view ProviderName, std::string_view RefreshToken)
	{
		if (OpenIdProviderExist(ProviderName) == false)
		{
			return {.Reason = fmt::format("provider '{}' is missing", ProviderName)};
		}

		OidcClient& Client = GetOpenIdClient(ProviderName);

		return Client.RefreshToken(RefreshToken);
	}

	void Shutdown()
	{
		BackgroundThread::Stop(m_BackgroundThread);
		SaveState();
	}

	void LoadState()
	{
		try
		{
			std::optional<std::string> Reason;

			IoBuffer Buffer =
				details::ReadEncryptedFile(m_Config.RootDirectory / "authstate"sv, m_Config.EncryptionKey, m_Config.EncryptionIV, Reason);

			if (!Buffer)
			{
				if (Reason)
				{
					ZEN_WARN("load auth state FAILED, reason '{}'", Reason.value());
				}

				return;
			}

			const CbValidateError ValidationError = ValidateCompactBinary(Buffer, CbValidateMode::All);

			if (ValidationError != CbValidateError::None)
			{
				ZEN_WARN("load serialized state FAILED, reason 'Invalid compact binary'");
				return;
			}

			if (CbObject AuthState = LoadCompactBinaryObject(Buffer))
			{
				for (CbFieldView ProviderView : AuthState["OpenIdProviders"sv])
				{
					CbObjectView ProviderObj = ProviderView.AsObjectView();

					std::string_view ProviderName = ProviderObj["Name"].AsString();
					std::string_view Url		  = ProviderObj["Url"].AsString();
					std::string_view ClientId	  = ProviderObj["ClientId"].AsString();

					AddOpenIdProvider({.Name = ProviderName, .Url = Url, .ClientId = ClientId});
				}

				for (CbFieldView TokenView : AuthState["OpenIdTokens"sv])
				{
					CbObjectView TokenObj = TokenView.AsObjectView();

					std::string_view ProviderName = TokenObj["ProviderName"sv].AsString();
					std::string_view RefreshToken = TokenObj["RefreshToken"sv].AsString();

					const bool Ok = AddOpenIdToken({.ProviderName = ProviderName, .RefreshToken = RefreshToken});

					if (!Ok)
					{
						ZEN_WARN("load serialized OpenId token for provider '{}' FAILED", ProviderName);
					}
				}
			}
		}
		catch (const std::exception& Err)
		{
			ZEN_ERROR("(de)serialize state FAILED, reason '{}'", Err.what());

			{
				std::unique_lock _(m_ProviderMutex);
				m_OpenIdProviders.clear();
			}

			{
				std::unique_lock _(m_TokenMutex);
				m_OpenIdTokens.clear();
			}
		}
	}

	void SaveState()
	{
		try
		{
			CbObjectWriter AuthState;

			{
				std::unique_lock _(m_ProviderMutex);

				if (m_OpenIdProviders.size() > 0)
				{
					AuthState.BeginArray("OpenIdProviders");
					for (const auto& Kv : m_OpenIdProviders)
					{
						AuthState.BeginObject();
						AuthState << "Name"sv << Kv.second->Name;
						AuthState << "Url"sv << Kv.second->Url;
						AuthState << "ClientId"sv << Kv.second->ClientId;
						AuthState.EndObject();
					}
					AuthState.EndArray();
				}
			}

			{
				std::unique_lock _(m_TokenMutex);

				AuthState.BeginArray("OpenIdTokens");
				if (m_OpenIdTokens.size() > 0)
				{
					for (const auto& Kv : m_OpenIdTokens)
					{
						AuthState.BeginObject();
						AuthState << "ProviderName"sv << Kv.first;
						AuthState << "RefreshToken"sv << Kv.second.RefreshToken;
						AuthState.EndObject();
					}
				}
				AuthState.EndArray();
			}

			std::filesystem::create_directories(m_Config.RootDirectory);

			std::optional<std::string> Reason;

			details::WriteEncryptedFile(m_Config.RootDirectory / "authstate"sv,
										AuthState.Save().GetBuffer().AsIoBuffer(),
										m_Config.EncryptionKey,
										m_Config.EncryptionIV,
										Reason);

			if (Reason)
			{
				ZEN_WARN("save auth state FAILED, reason '{}'", Reason.value());
			}
		}
		catch (const std::exception& Err)
		{
			ZEN_WARN("serialize state FAILED, reason '{}'", Err.what());
		}
	}

	void BackgroundThreadEntry()
	{
		SetCurrentThreadName("auth");

		for (;;)
		{
			std::cv_status SignalStatus = BackgroundThread::WaitForSignal(m_BackgroundThread);

			if (m_BackgroundThread.Running.load() == false)
			{
				break;
			}

			if (SignalStatus != std::cv_status::timeout)
			{
				continue;
			}

			{
				// Refresh Open ID token(s)

				std::vector<OpenIdTokenMap::value_type> ExpiredTokens;

				{
					std::unique_lock _(m_TokenMutex);

					for (const auto& Kv : m_OpenIdTokens)
					{
						const Seconds ExpiresIn = std::chrono::duration_cast<Seconds>(Kv.second.ExpireTime - Clock::now());
						const bool	  Expired	= ExpiresIn < Seconds(m_BackgroundThread.Interval * 2);

						if (Expired)
						{
							ExpiredTokens.push_back(Kv);
						}
					}
				}

				if (ExpiredTokens.empty())
				{
					continue;
				}

				ZEN_DEBUG("refreshing {} OpenID token(s)", ExpiredTokens.size());

				for (const auto& Kv : ExpiredTokens)
				{
					OidcClient::RefreshTokenResult RefreshResult = RefreshOpenIdToken(Kv.first, Kv.second.RefreshToken);

					if (RefreshResult.Ok)
					{
						ZEN_DEBUG("refresh access token from provider '{}' Ok", Kv.first);

						auto Token = OpenIdToken{.IdentityToken = RefreshResult.IdentityToken,
												 .RefreshToken	= RefreshResult.RefreshToken,
												 .AccessToken	= fmt::format("Bearer {}"sv, RefreshResult.AccessToken),
												 .ExpireTime	= Clock::now() + Seconds(RefreshResult.ExpiresInSeconds)};

						{
							std::unique_lock _(m_TokenMutex);
							m_OpenIdTokens.insert_or_assign(Kv.first, std::move(Token));
						}
					}
					else
					{
						ZEN_WARN("refresh access token from provider '{}' FAILED, reason '{}'", Kv.first, RefreshResult.Reason);
					}
				}
			}
		}
	}

	struct BackgroundThread
	{
		std::chrono::seconds	Interval{10};
		std::mutex				Mutex;
		std::condition_variable Signal;
		std::atomic_bool		Running{true};
		std::thread				Thread;

		static void Stop(BackgroundThread& State)
		{
			if (State.Running.load())
			{
				State.Running.store(false);
				State.Signal.notify_one();
			}

			if (State.Thread.joinable())
			{
				State.Thread.join();
			}
		}

		static std::cv_status WaitForSignal(BackgroundThread& State)
		{
			std::unique_lock Lock(State.Mutex);
			return State.Signal.wait_for(Lock, State.Interval);
		}
	};

	struct OpenIdProvider
	{
		std::string					Name;
		std::string					Url;
		std::string					ClientId;
		std::unique_ptr<OidcClient> HttpClient;
	};

	struct OpenIdToken
	{
		std::string IdentityToken;
		std::string RefreshToken;
		std::string AccessToken;
		TimePoint	ExpireTime{};
	};

	using OpenIdProviderMap = std::unordered_map<std::string, std::unique_ptr<OpenIdProvider>>;
	using OpenIdTokenMap	= std::unordered_map<std::string, OpenIdToken>;

	LoggerRef Log() { return m_Log; }

	AuthConfig		  m_Config;
	LoggerRef		  m_Log;
	BackgroundThread  m_BackgroundThread;
	OpenIdProviderMap m_OpenIdProviders;
	OpenIdTokenMap	  m_OpenIdTokens;
	std::mutex		  m_ProviderMutex;
	std::shared_mutex m_TokenMutex;
};

/// Creates the default AuthMgr implementation. Constructing it loads any persisted auth state and starts the
/// background token-refresh thread.
std::unique_ptr<AuthMgr>
AuthMgr::Create(const AuthConfig& Config)
{
	std::unique_ptr<AuthMgr> Manager = std::make_unique<AuthMgrImpl>(Config);
	return Manager;
}

}  // namespace zen
